library(lubridate)
library(ape)
library(tidyverse)
library(pegas)

# Load the sampling dates. samplingDate is parsed as month/day/two-digit-year;
# entries that do not parse come back as NA, which is expected for some rows.
# decimalDate is the parsed date expressed as a decimal year.

sample.dates <- read_csv("sample_dates.csv") %>%
  mutate(
    samplingDate = parse_date_time(samplingDate, "%m/%d/%y"),
    decimalDate = decimal_date(samplingDate)
  )

# Lookup table from original tip names (tip.name) to replacement labels
# (new.label).
tip.renaming <- read_csv("new_tip_names.csv")

# Translate every sequence name through the lookup table; match() uses the
# first matching entry. Sequences that end up without a label (no entry in the
# table, or an NA label) keep their original sequence name.
sample.dates$new.tip <- tip.renaming$new.label[match(sample.dates$sequenceName, tip.renaming$tip.name)]
sample.dates$new.tip[is.na(sample.dates$new.tip)] <- sample.dates$sequenceName[is.na(sample.dates$new.tip)]

# The patient identifier is the part of the tip label before the first
# underscore. vapply() (instead of sapply()) keeps the column a plain, unnamed
# character vector and still yields character(0) rather than list() when the
# table is empty.
sample.dates$split.patient <- vapply(
  strsplit(as.character(sample.dates$new.tip), "_", fixed = TRUE),
  function(parts) parts[1],
  character(1)
)

# Read the alignment and keep only the sequences whose label starts with "T",
# minus four specific isolates that are left out of the analysis.
snp.alignment <- local({
  full.alignment <- read.dna("alignment_postgubbins.fas", format = "fasta")
  labels <- rownames(full.alignment)
  dropped <- c("T330_N01_C08", "T249_N03_C02", "T234_W02_C01", "T035_N04_C02")
  full.alignment[which(startsWith(labels, "T") & !(labels %in% dropped)), ]
})

# Square matrix of pairwise distances between the retained sequences
# (ape model "N", with pairwise deletion).
snp.dists <- dist.dna(snp.alignment, model = "N", pairwise.deletion = TRUE, as.matrix = TRUE)

# Translate original tip name(s) into their replacement label(s).
#
# old.name: character vector of original tip names (a single name in the
#   existing calls).
# Returns a vector the same length as old.name: the new.label of the first
#   tip.renaming row whose tip.name matches, or the unchanged old name when
#   there is no match or the matched label is NA. This mirrors how
#   sample.dates$new.tip is filled in (first match, fall back to the original
#   name), so matrix labels and sample.dates labels always agree.
# Reads the global lookup table tip.renaming at call time.
old.to.new <- function(old.name){
  renamed <- tip.renaming$new.label[match(old.name, tip.renaming$tip.name)]
  unresolved <- is.na(renamed)
  renamed[unresolved] <- old.name[unresolved]
  renamed
}

# Relabel both dimensions of the distance matrix with the new tip names.
# vapply() with USE.NAMES = FALSE hands back plain character vectors, so the
# old names never get attached to the dimnames (no follow-up names<- fix-up is
# needed), and it stops with an error if a lookup does not return exactly one
# label instead of silently producing a list.
rownames(snp.dists) <- vapply(rownames(snp.dists), old.to.new, character(1), USE.NAMES = FALSE)
colnames(snp.dists) <- vapply(colnames(snp.dists), old.to.new, character(1), USE.NAMES = FALSE)

# Patient identifier (the text before the first underscore) for every column
# of the matrix; the result is named by the column's tip label.
alignment.patients <- vapply(
  colnames(snp.dists),
  function(tip) strsplit(tip, "_", fixed = TRUE)[[1]][1],
  character(1)
)

# Distinct patients present in the distance matrix, in sorted order (an NA
# identifier, if there is one, goes last).
patients <- sort(unique(alignment.patients), na.last = TRUE)

# Earliest decimal sampling date of each patient, ignoring the four excluded
# isolates. The result is a numeric vector named by patient; vapply() keeps it
# numeric even when there are no patients.
# NOTE(review): min() is called without na.rm, so a patient with any unparsed
# (NA) sampling date gets NA here, and a patient with no rows in sample.dates
# gets Inf (with a warning) - confirm both are intended.
earliest.tip.date <- vapply(patients, function(a.patient){
  from.patient <- sample.dates$split.patient == a.patient &
    !(sample.dates$new.tip %in% c("T330_N01_C08", "T249_N03_C02", "T234_W02_C01", "T035_N04_C02"))
  min(sample.dates$decimalDate[which(from.patient)])
}, numeric(1))


# Tip labels belonging to each patient, ignoring the four excluded isolates.
# seq_along() (rather than 1:length()) keeps this an empty list when there are
# no patients instead of iterating over c(1, 0).
all.tips.by.patient <- lapply(seq_along(patients), function(i){
  from.patient <- sample.dates$split.patient == patients[i] &
    !(sample.dates$new.tip %in% c("T330_N01_C08", "T249_N03_C02", "T234_W02_C01", "T035_N04_C02"))
  sample.dates$new.tip[which(from.patient)]
})

# Same, restricted to the tips sampled on the patient's earliest date. The
# patient mask is computed once and reused for both the minimum and the
# selection, so the two can never drift apart.
all.tips.first.date.by.patient <- lapply(seq_along(patients), function(i){
  from.patient <- sample.dates$split.patient == patients[i] &
    !(sample.dates$new.tip %in% c("T330_N01_C08", "T249_N03_C02", "T234_W02_C01", "T035_N04_C02"))
  first.date <- min(sample.dates$decimalDate[which(from.patient)])
  sample.dates$new.tip[which(from.patient & sample.dates$decimalDate == first.date)]
})

# Both lists are indexed by patient identifier further down.
names(all.tips.by.patient) <- patients
names(all.tips.first.date.by.patient) <- patients


# For every patient: among isolates from *other* patients sampled strictly
# before this patient's earliest date, find the one(s) at the smallest SNP
# distance from any of the patient's tips, and return the median distance
# between the patient's tips and those closest isolate(s).
#
# tips.by.patient: list named by patient identifier giving the tips that form
#   the rows of the comparison (all tips, or first-date tips only).
# Returns a numeric vector in the order of `patients`; NA_real_ where there is
#   nothing to compare (no earlier isolate from another patient, or no tips
#   for the patient).
# Reads the globals patients, earliest.tip.date, sample.dates and snp.dists.
# NOTE(review): every tip selected from sample.dates must be a row/column of
# snp.dists, otherwise the matrix subset stops with "subscript out of bounds".
closest.earlier.median.dist <- function(tips.by.patient){
  excluded.tips <- c("T330_N01_C08", "T249_N03_C02", "T234_W02_C01", "T035_N04_C02")

  vapply(seq_along(patients), function(i){
    a.host <- patients[i]
    a.date <- earliest.tip.date[i]

    # rows are this patient
    allowed.rows <- tips.by.patient[[a.host]]

    # columns are earlier isolates from everyone else
    allowed.isolates <- sample.dates$new.tip[which(sample.dates$decimalDate < a.date
                                                   & sample.dates$split.patient != a.host
                                                   & !(sample.dates$new.tip %in% excluded.tips))]

    distances.of.interest <- snp.dists[allowed.rows, allowed.isolates, drop = FALSE]
    if(length(distances.of.interest) == 0){
      return(NA_real_)
    }

    # which(arr.ind = TRUE) reports a column once per row that attains the
    # minimum. unique() makes each closest isolate count once, so an isolate
    # is not over-weighted in the pooled median just because several of the
    # patient's tips tie on it.
    closest <- which(distances.of.interest == min(distances.of.interest), arr.ind = TRUE)
    min.col <- unique(closest[, "col"])
    median(distances.of.interest[, min.col])
  }, numeric(1))
}

# Using all of each patient's tips as rows.
median.dist.to.closest.earlier.sample <- closest.earlier.median.dist(all.tips.by.patient)

# Using only the tips from each patient's earliest sampling date as rows.
median.dist.to.closest.earlier.sample.first.date.only <- closest.earlier.median.dist(all.tips.first.date.by.patient)

# Patient identifiers listed in the id column of this file are the ones drawn
# as "Trace" in the figures.
single.sampled.ids <- read_csv("trace_colonisation_patients.csv") %>% pull(id)

# Axis compression used by both figures: distances up to 50 are plotted as-is,
# larger ones are placed at 50 + 50 * log16(x / 50), so 100 -> 62.5,
# 200 -> 75, 400 -> 87.5 and 800 -> 100.
# NOTE: this name masks base::transform() for the rest of the script; it is
# kept unchanged in case later code refers to it.
transform <- function(x) 50 + 50*log(x/50)/log(16)

# Minor grid lines at every 25 SNPs from 75 to 800, on the compressed scale.
newbreaks <- transform(seq(75, 800, by=25))

# Add the plotting columns to a host/distance data frame:
#   big                    - TRUE when distance exceeds 50 (NA if distance is NA)
#   distance.for.placement - x position on the compressed axis
add.placement.columns <- function(distances.df){
  distances.df %>%
    mutate(big = distance > 50) %>%
    mutate(distance.for.placement = as.numeric(ifelse(big, transform(distance), distance)))
}

# Stacked dot plot of distance.for.placement, filled by trace status. Rows
# with an NA distance are dropped. The fill labels are named by level so they
# remain correct (and the scale does not fail) when only one of TRUE/FALSE is
# present in the data.
closest.distance.dotplot <- function(placed.df){
  ggplot(placed.df %>% filter(!is.na(big)), aes(x=distance.for.placement, fill=trace)) +
    geom_dotplot(stackgroups = TRUE, binwidth=1, binpositions="all") +
    scale_y_continuous(NULL, breaks = NULL) +
    theme_bw() +
    xlab("Median SNP distance to closest earlier isolate") +
    scale_fill_brewer(palette="Set1",
                      labels = c("FALSE" = "Non-trace", "TRUE" = "Trace")) +
    theme(legend.key.height = unit(1.1,"line"), legend.title = element_blank()) +
    scale_x_continuous(labels=c(0, 25, 50, 100, 200, 400, 800), limits = c(0, 100), minor_breaks = newbreaks, breaks = c(0, 25, 50, 62.5, 75, 87.5, 100)) +
    geom_vline(xintercept = 50, linetype = "dotted")
}

# Figure 2: all tips of each patient.
distances.df.all <- data.frame(host = patients, distance = median.dist.to.closest.earlier.sample, stringsAsFactors = FALSE)
distances.df.all$trace <- distances.df.all$host %in% single.sampled.ids

distances.df.jimmied.all <- add.placement.columns(distances.df.all)

figure.2 <- closest.distance.dotplot(distances.df.jimmied.all)
figure.2

# The plot is passed explicitly rather than relying on last_plot().
ggsave("Figure2.pdf", plot = figure.2, width=9, height=1.2)

# Figure 2 supplement 1: only the tips from each patient's earliest date.
distances.df.earliest <- data.frame(host = patients, distance = median.dist.to.closest.earlier.sample.first.date.only, stringsAsFactors = FALSE)
distances.df.earliest$trace <- distances.df.earliest$host %in% single.sampled.ids

distances.df.jimmied.earliest <- add.placement.columns(distances.df.earliest)

figure.2.s1 <- closest.distance.dotplot(distances.df.jimmied.earliest)
figure.2.s1

ggsave("Figure2S1.pdf", plot = figure.2.s1, width=9, height=1.2)

